# from __future__ import annotations

import math

import numpy as np
import torch
import torch.nn as nn
from rich import print
from rich.columns import Columns
from rich.console import Console
from torch import Tensor
import torch.nn.functional as F

# Module-wide rich console; `DLRTLinearFixed.print_means` prints its column layout through it.
console = Console(width=140)

from .dlrt_module import DLRTModule
# NOTE: import-time side effect — importing this module sets the process-wide
# default floating-point dtype for torch to float32.
torch.set_default_dtype(torch.float32)


def DLRTLinear(
        in_features: int,
        out_features: int,
        adaptive: bool = True,
        low_rank_percent: float = None,
        bias: bool = True,
        init_method: str = "random",
        device=None,
        dtype=None,
        eps_adapt: float = 0.01,
        pretrain: bool = False,
):
    """
    Build a DLRT linear layer, either rank-adaptive or fixed-rank.

    Parameters
    ----------
    in_features
        size of each input sample
    out_features
        size of each output sample
    adaptive
        if True return a ``DLRTLinearAdaptive``, otherwise a ``DLRTLinearFixed``
    low_rank_percent
        fraction used to derive the inner rank. For the fixed layer the rank is
        ``int(low_rank_percent * min(in_features, out_features))``; ``None`` lets the
        selected layer pick its own default.
    bias
        whether the layer should have an additive bias
    init_method
        forwarded to the fixed-rank layer only
    device, dtype
        forwarded to the selected layer
    eps_adapt
        truncation tolerance of the adaptive layer. It is forwarded as the layer's
        ``tau`` argument, because ``DLRTLinearAdaptive.__init__`` has no ``eps_adapt``
        parameter and passing it under that name raises ``TypeError``.
    pretrain
        forwarded to the adaptive layer only

    Returns
    -------
    DLRTLinearAdaptive or DLRTLinearFixed
    """
    if adaptive:
        return DLRTLinearAdaptive(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            low_rank_percent=low_rank_percent,
            tau=eps_adapt,
            device=device,
            dtype=dtype,
            pretrain=pretrain,
        )

    # Fixed rank: translate the fraction into an absolute inner rank.
    rank = None
    if low_rank_percent is not None:
        rank = int(low_rank_percent * min(in_features, out_features))
    return DLRTLinearFixed(
        in_features=in_features,
        out_features=out_features,
        bias=bias,
        rank=rank,
        init_method=init_method,
        device=device,
        dtype=dtype,
    )


class DLRTLinearFixed(DLRTModule):  # TODO: not yet ported to the new version of the layer
    """
    Fixed-rank DLRT linear layer (work in progress).

    NOTE(review): this class looks half-migrated. ``__init__`` and ``reset_parameters``
    create/initialise ``C``, ``Us``, ``Ks``, ``U_hats``, ``Vst`` and ``M_hats``, while
    ``forward``, ``print_means``, ``get_classic_weight_repr``,
    ``_change_params_requires_grad`` and all ``*_preprocess`` / ``*_postprocess`` methods
    read ``k``, ``s``, ``lt``, ``u``, ``vt``, ``unp1``, ``vtnp1``, ``n`` and ``m``, none of
    which are assigned anywhere in this class. Unless ``DLRTModule`` provides them, the
    class cannot be used as written — confirm before relying on it.
    """
    # should this instead inherit from nn.Linear?
    #   doesn't need to, everything in nn.Linear is overwritten
    # overwrite the original layer depending on its type?
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(
            self,
            in_features: int,
            out_features: int,
            rank: int = None,
            bias: bool = True,
            init_method: str = "random",
            device=None,
            dtype=None,
    ) -> None:
        """
        Create the layer's factor tensors and initialise them.

        Parameters
        ----------
        in_features
            size of each input sample
        out_features
            size of each output sample
        rank
            inner rank of the decomposition; stored as ``self.low_rank``
        bias
            whether to create a bias parameter (see the review note on the condition below)
        init_method
            must be "random" or "svd"; only stored and validated in this method
        device
            device the factor tensors are moved to
        dtype
            dtype used for the bias parameter only
        """

        # NOTE(review): `self.out_features` / `self.in_features` are only assigned further
        # down, and this line runs before `super().__init__()`, so the `rank is None`
        # fallback looks like it fails with an attribute error — confirm.
        self.low_rank = rank if rank is not None else [self.out_features, self.in_features]

        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        # self.step = 1
        self.device = device
        self.tau = 0.0
        self.in_features = in_features
        self.out_features = out_features
        # NOTE(review): `bias=False` also satisfies `bias is not None`, so this condition is
        # false only for `bias=None`.
        if (isinstance(bias, bool) and bias) or bias is not None:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)

        # parameter count of the equivalent dense weight matrix
        self.basic_number_weights = out_features * in_features

        self.dlrt = True
        # ---- part added for the new version ----
        # NOTE(review): `self.rank` is read here but first assigned two statements later,
        # and the local `low_rank` is never used afterwards. TODO: take care of fixed rank
        low_rank = self.rank
        self.dims = [self.out_features, self.in_features]
        self.rmax = [int(d // 2) for d in self.dims]
        # NOTE(review): `rank` is annotated as int but is iterated by `zip` here — presumably
        # a per-dimension list is expected; confirm against the caller.
        self.rank = [min([r, rmax_r]) for r, rmax_r in zip(rank, self.dims)]

        self.init_method = init_method
        assert init_method in ["random", "svd"], "init_method must be in ['random', 'svd']"

        # NOTE(review): `self.dynamic_rank` is not assigned anywhere in this class.
        self.C = torch.nn.Parameter(torch.abs(torch.randn(size=[int(2 * s) for s in self.dynamic_rank])).to(device))
        self.Us = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.randn(d, r)).to(device) for d, r in zip(self.dims, self.dynamic_rank)])
        self.Ks = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.randn(d, int(r))).to(device) for (d, r) in zip(self.dims, self.rmax)])
        U_hats_list = [torch.nn.Parameter(torch.randn(d, int(2 * r)).to(device), requires_grad=False)
                       for (d, r) in zip(self.dims, self.rmax)]
        self.U_hats = torch.nn.ParameterList(U_hats_list)

        self.Vst = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.randn(self.rmax[0], self.in_features), requires_grad=False).to(device),
             torch.nn.Parameter(torch.randn(self.rmax[1], self.out_features), requires_grad=False).to(
                 device)])  # already transposed

        self.M_hats = torch.nn.ParameterList(
            [torch.nn.Parameter(torch.randn(int(2 * r), int(r)), requires_grad=False).to(device) for r in self.rmax])

        self.reset_parameters()
        # forward() starts in the k-step branch
        self.train_case = "k"

    def extra_repr(self) -> str:
        """Return the feature sizes, stored rank and bias flag shown in the module repr."""
        return (
            f"in_features={self.in_features}, rank={self.low_rank}, "
            f"out_features={self.out_features}, bias={self.bias is not None}"
        )

    def print_means(self):
        """
        Print mean / min / max / requires_grad of every factor tensor in columns.

        NOTE(review): reads the `k`, `s`, `lt`, `u`, `unp1`, `vt`, `vtnp1`, `n`, `m`
        attributes, which `__init__` does not create.
        """
        shapes = []
        shapes.append(f"k: {self.k.mean():.4f} {self.k.min():.4f} {self.k.max():.4f} {self.k.requires_grad}")
        shapes.append(
            f"s: {self.s.mean():.4f} {self.s.min():.4f} {self.s.max():.4f} " f"{self.s.requires_grad}",
        )
        shapes.append(
            f"lt: {self.lt.mean():.4f} {self.lt.min():.4f} {self.lt.max():.4f}" f" {self.lt.requires_grad}",
        )
        shapes.append(
            f"u: {self.u.mean():.4f} {self.u.min():.4f} {self.u.max():.4f} " f"{self.u.requires_grad}",
        )
        shapes.append(
            f"unp1: {self.unp1.mean():.4f} {self.unp1.min():.4f} "
            f"{self.unp1.max():.4f} {self.unp1.requires_grad}",
        )
        shapes.append(
            f"vt: {self.vt.mean():.4f} {self.vt.min():.4f} {self.vt.max():.4f}" f" {self.vt.requires_grad}",
        )
        shapes.append(
            f"vtnp1: {self.vtnp1.mean():.4f} {self.vtnp1.min():.4f} "
            f"{self.vtnp1.max():.4f} {self.vtnp1.requires_grad}",
        )
        shapes.append(
            f"n: {self.n.mean():.4f} {self.n.min():.4f} " f"{self.n.max():.4f} {self.n.requires_grad}",
        )
        shapes.append(
            f"m: {self.m.mean():.4f} {self.m.min():.4f} " f"{self.m.max():.4f} {self.m.requires_grad}",
        )
        if self.bias is not None:
            shapes.append(
                f"bias: {self.bias.mean():.4f} {self.bias.min():.4f} "
                f"{self.bias.max():.4f} {self.bias.requires_grad}",
            )
        # if self.rank == 0: # and self.counter % 100 == 0:
        columns = Columns(shapes, equal=True, expand=True)
        # console.rule("All shapes in linear")
        console.print(columns)

    def get_classic_weight_repr(self):
        """Return the product ``k @ s @ lt``; a 1-D ``s`` is expanded to a diagonal matrix first."""
        if self.s.ndim == 1:
            return self.k @ torch.diag(self.s) @ self.lt
        return self.k @ self.s @ self.lt

    @torch.no_grad()
    def reset_parameters(self):
        """Kaiming-uniform initialise all factor tensors, then draw the bias uniformly."""

        nn.init.kaiming_uniform_(self.C, a=math.sqrt(5))

        for u, vt, u_hat, k, m in zip(self.Us, self.Vst, self.U_hats, self.Ks, self.M_hats):
            nn.init.kaiming_uniform_(u, a=math.sqrt(5))
            nn.init.kaiming_uniform_(vt, a=math.sqrt(5))
            nn.init.kaiming_uniform_(u_hat, a=math.sqrt(5))
            # nn.init.kaiming_uniform_(self.vtnp1, a=math.sqrt(5))
            nn.init.kaiming_uniform_(k, a=math.sqrt(5))
            nn.init.kaiming_uniform_(m, a=math.sqrt(5))

        if self.bias is not None:
            # NOTE(review): the einsum spec contains '-_>' rather than '->', and it names
            # five operands while `self.C` plus the two entries of `self.Us` give three —
            # this call looks like it cannot succeed as written; confirm.
            w = torch.einsum('ijkl,ai,bj,ck,dl-_>abcd', self.C,
                             *self.Us)  # torch.linalg.multi_dot([self.k, self.s, self.lt])
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(w)
            # bias ~ U(-1/sqrt(fan_in), 1/sqrt(fan_in)); bound is 0 when fan_in is 0
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        """
        Apply the layer using the factor pair selected by ``self.train_case``.

        The k-step branch is also used whenever the module is not in training mode.
        """
        if self.train_case == "k" or not self.training:  # k-step
            ret = torch.linalg.multi_dot([input, self.k, self.vt])
        elif self.train_case == "l":  # l-step
            ret = torch.linalg.multi_dot([input, self.u, self.lt])
            # ret = (input @ self.u) @ self.lt
        else:  # s-step
            ret = torch.linalg.multi_dot([input, self.unp1, self.s, self.vtnp1])
            # ret = ((input @ self.unp1) @ self.s) @ self.vtnp1
        return ret if self.bias is None else ret + self.bias

    def _change_params_requires_grad(self, requires_grad):
        """
        Set ``requires_grad`` on k, s, lt and the bias; the auxiliary tensors are always frozen.

        NOTE(review): `self.bias` is dereferenced unconditionally, although `__init__` can
        register it as None.
        """
        self.u.requires_grad = False  # requires_grad
        self.s.requires_grad = requires_grad
        self.vt.requires_grad = False  # requires_grad
        self.unp1.requires_grad = False  # requires_grad
        self.vtnp1.requires_grad = False  # requires_grad
        self.k.requires_grad = requires_grad
        self.lt.requires_grad = requires_grad
        self.n.requires_grad = False  # requires_grad
        self.m.requires_grad = False  # requires_grad
        self.bias.requires_grad = requires_grad

    @torch.no_grad()
    def k_preprocess(self):
        """Freeze everything, set ``k = u @ s`` and make ``k`` the only trainable tensor."""
        self._change_params_requires_grad(False)
        # k prepro
        # k -> aux_U @ s
        # self.k = nn.Parameter(self.u @ self.s, requires_grad=True)
        self.k.set_(self.u @ self.s)
        # TODO: sorting s at the top of the training loop might break everything
        # self.s.set_(torch.sort(self.s, descending=True).values)

        self.k.requires_grad = True

    @torch.no_grad()
    def l_preprocess(self):
        """Freeze everything, set ``lt = s @ vt`` and make ``lt`` the only trainable tensor."""
        self._change_params_requires_grad(False)
        # lt -> s @ aux_Vt
        # NOTE transposed from paper!
        self.lt.set_(self.s @ self.vt)
        self.lt.requires_grad = True

    @torch.no_grad()
    def k_postprocess(self):
        """Store the Q factor of ``qr(k)`` in ``unp1`` and set ``n = unp1.T @ u``."""
        # NOTE must run after 'l' forward step b/c self.u is used in the l forward step
        self._change_params_requires_grad(False)
        # aux_Unp1 -> q from qr(k)
        #   aux_Unp1 used in s-step forward, can keep in u
        self.unp1.set_(torch.linalg.qr(self.k)[0])
        # aux_N -> aux_Unp1.T @ aux_U
        #   used in setting s,
        self.n.set_(self.unp1.T @ self.u)

    @torch.no_grad()
    def l_postprocess(self):
        """Store the transposed Q factor of ``qr(lt.T)`` in ``vtnp1`` and set ``m = vtnp1 @ vt.T``."""
        self._change_params_requires_grad(False)
        # aux_Vtnp1 -> q from qr(lt.T)
        self.vtnp1.set_(torch.linalg.qr(self.lt.T)[0].T)
        # aux_M -> aux_Vtnp1 @ aux_Vt.T
        self.m.set_(self.vtnp1 @ self.vt.T)

    @torch.no_grad()
    def s_preprocess(self):
        """Set ``s = n @ s @ m.T``, enable grads on ``s`` (and the bias), and adopt the new bases."""
        self._change_params_requires_grad(False)
        if self.bias is not None:
            self.bias.requires_grad = True
        # set aux_U -> aux_Unp1  # done above
        # set aux_Vt -> aux_Vtnp1  # done previously now
        # set s -> (aux_N @ s) @ aux_M.T
        # self.s = nn.Parameter(self.n @ self.s @ self.m.T, requires_grad=True)
        self.s.set_(self.n @ self.s @ self.m.T)
        self.s.requires_grad = True
        # overwrite the old bases once the new ones exist
        self.u.set_(self.unp1.data)
        self.vt.set_(self.vtnp1.data)


class DLRTLinearAdaptive(DLRTModule):
    """
    Rank-adaptive DLRT linear layer.

    The effective weight, in ``nn.Linear`` orientation ``(out_features, in_features)``,
    is ``U[:, :r] @ S_hat[:r, :r] @ V[:, :r].T`` with ``r = self.dynamic_rank``. All
    factor tensors are allocated at their maximum size (``rmax`` / ``2 * rmax`` columns)
    and only the leading ``r`` / ``2 * r`` slice is used at any time.

    ``forward`` dispatches on ``self.step``, which is not assigned in this class.
    NOTE(review): presumably ``DLRTModule`` or the training loop sets it — confirm.
    """
    # overwrite the original layer depending on its type?
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(
            self,
            in_features: int,
            out_features: int,
            low_rank_percent: float = None,
            # rmax: int = None,
            bias: bool = True,
            tau: float = 0.1,
            device=None,
            dtype=None,
            pretrain: bool = True,
    ) -> None:
        """
        TODO: this layer has some elements which are transposed from the paper, need to
            mark and double check

        Parameters
        ----------
        in_features
            size of each input sample
        out_features
            size of each output sample
        low_rank_percent
            only used to derive ``self.low_rank``, which is reported by ``extra_repr``.
            The working rank always starts at ``min(in_features, out_features) // 2``.
        bias
            create a bias parameter. ``False`` and ``None`` both disable it.
        tau
            relative tolerance used by ``S_postprocess_step`` when truncating the rank
        device
            device the factor tensors are moved to
        dtype
            dtype used for the bias parameter
        pretrain
            when truthy an (uninitialised) dense ``fullweight`` parameter is allocated;
            ``self.pretrain`` itself is always set to False.
        """
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.in_features = in_features
        self.out_features = out_features

        # The previous condition `(isinstance(bias, bool) and bias) or bias is not None`
        # was also true for bias=False, so a bias was created even when disabled.
        has_bias = bias if isinstance(bias, bool) else bias is not None
        if has_bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)

        # `low_rank` is informational only (see extra_repr); the rank bound computed here
        # is not the one the layer works with.
        smallest_dim = min([in_features, out_features])
        if low_rank_percent is None:
            roots = np.roots([1, in_features + out_features, in_features * out_features])
            pos_coeff = roots[roots > 0]  # TODO: adjust factor?
            if len(pos_coeff) < 1:
                rank_bound = smallest_dim // 2
            else:
                rank_bound = int(np.floor(pos_coeff[-1]))
            if rank_bound < 10:
                rank_bound = 20
            self.low_rank = rank_bound // 2
        else:
            self.low_rank = int((smallest_dim // 2) * low_rank_percent)

        self.dims = [self.out_features, self.in_features]
        self.rmax = min(self.dims) // 2
        self.device = device
        self.dtype = dtype
        self.rank = self.rmax
        self.dynamic_rank = self.rank

        # parameter count of the equivalent dense weight matrix
        self.basic_number_weights = out_features * in_features

        self.tau = tau
        self.dlrt = True

        self.pretrain = False  # pretrain
        if pretrain:
            self.fullweight = nn.Parameter(
                torch.empty(out_features, in_features),
                requires_grad=True,
            )

        # Factor tensors. Every one of them is re-initialised by reset_parameters(), so
        # no SVD is needed here: the diagonal of S_hat is simply a descending sort of
        # |randn|, which is what the SVD of that diagonal matrix would return.
        rmax = self.rmax
        s_ordered = torch.sort(torch.abs(torch.randn(2 * rmax)), descending=True).values
        self.U = torch.nn.Parameter(torch.randn(out_features, rmax).to(device), requires_grad=False)
        self.S_hat = torch.nn.Parameter(torch.diag(s_ordered).to(device))
        self.V = torch.nn.Parameter(torch.randn(in_features, rmax).to(device), requires_grad=False)
        self.U_hat = torch.nn.Parameter(torch.randn(out_features, 2 * rmax).to(device), requires_grad=False)
        self.V_hat = torch.nn.Parameter(torch.randn(in_features, 2 * rmax).to(device), requires_grad=False)
        self.K = torch.nn.Parameter(torch.randn(out_features, rmax).to(device))
        self.L = torch.nn.Parameter(torch.randn(in_features, rmax).to(device))
        self.N_hat = torch.nn.Parameter(torch.randn(2 * rmax, rmax).to(device), requires_grad=False)
        self.M_hat = torch.nn.Parameter(torch.randn(2 * rmax, rmax).to(device), requires_grad=False)

        self.reset_parameters()
        # self.train_case = "k"

    def extra_repr(self) -> str:
        """Return the feature sizes, ``low_rank`` and bias flag shown in the module repr."""
        return (
            f"in_features={self.in_features}, low_rank={self.low_rank}, "
            f"out_features={self.out_features}, bias={self.bias is not None}"
        )

    def get_classic_weight_repr(self):
        """
        Return the dense ``(out_features, in_features)`` weight at the current rank.

        Matches what ``forward`` applies in the ``'test'`` step.
        """
        r = self.dynamic_rank
        return self.U[:, :r] @ self.S_hat[:r, :r] @ self.V[:, :r].T

    def reset_parameters(self) -> None:
        """Kaiming-uniform initialise every factor tensor, then draw the bias uniformly."""
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        factors = (
            self.S_hat, self.U, self.V, self.U_hat, self.V_hat,
            self.K, self.L, self.M_hat, self.N_hat,
        )
        for factor in factors:
            nn.init.kaiming_uniform_(factor, a=math.sqrt(5))

        if self.bias is not None:
            # fan_in of the (out_features, in_features) weight is its column count, so
            # the product U @ S_hat @ V.T does not have to be materialised to obtain it.
            fan_in = self.in_features
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def _change_params_requires_grad(self, requires_grad):
        """
        Set ``requires_grad`` on the trainable factors (K, S_hat, L) and the bias.

        The basis and projection tensors (U, V, U_hat, V_hat, M_hat, N_hat) are always
        frozen. A layer without bias is handled.
        """
        for trainable in (self.K, self.S_hat, self.L):
            trainable.requires_grad = requires_grad
        for frozen in (self.U, self.V, self.U_hat, self.V_hat, self.M_hat, self.N_hat):
            frozen.requires_grad = False
        if self.bias is not None:
            self.bias.requires_grad = requires_grad

    def change_training_case(self, case):
        """Record ``case`` in ``self.train_case``."""
        # switch -> if current train case is k/l, do post for
        self.train_case = case

    def forward(self, input: Tensor) -> Tensor:
        """
        Apply the layer using the factors selected by ``self.step``.

        ``1`` / ``3`` use ``K V^T`` (step 1 additionally re-enables grads on K),
        ``2`` / ``4`` use ``U L^T`` (step 2 additionally re-enables grads on L),
        ``'core'`` uses the augmented ``U_hat S_hat V_hat^T`` at rank ``2 * r`` and
        ``'test'`` uses ``U S_hat V^T`` at rank ``r``.

        Raises
        ------
        ValueError
            if ``self.step`` is none of the values above.
        """
        step = self.step
        r = self.dynamic_rank

        if step == 1 or step == 3:
            if step == 1:
                self.K.requires_grad = True
            x = F.linear(input, self.V[:, :r].T)
            return F.linear(x, self.K[:, :r], self.bias)

        if step == 2 or step == 4:
            # (step 4: try keeping the gradients on L)
            if step == 2:
                self.L.requires_grad = True
            x = F.linear(input, self.L[:, :r].T)
            return F.linear(x, self.U[:, :r], self.bias)

        if step == 'core':
            core, left, right = self.S_hat[:2 * r, :2 * r], self.U_hat[:, :2 * r], self.V_hat[:, :2 * r]
        elif step == 'test':
            core, left, right = self.S_hat[:r, :r], self.U[:, :r], self.V[:, :r]
        else:
            raise ValueError(f"unknown step {step!r}; expected 1, 2, 3, 4, 'core' or 'test'")

        x = F.linear(input, right.T)
        x = F.linear(x, core)
        return F.linear(x, left, self.bias)

    @staticmethod
    def _orthonormal_basis(mat: Tensor) -> Tensor:
        """Return the reduced-QR Q factor of ``mat``, falling back to NumPy if torch fails."""
        try:
            q, _ = torch.linalg.qr(mat)
        except RuntimeError:
            q_np, _ = np.linalg.qr(mat.detach().cpu().numpy())
            # bring the result back to where the input lives
            q = torch.from_numpy(q_np).to(device=mat.device, dtype=mat.dtype)
        return q

    @staticmethod
    def _svd(mat: Tensor):
        """Return ``(U, S, Vh)`` of ``mat`` as tensors, falling back to NumPy if torch fails."""
        try:
            return torch.linalg.svd(mat)
        except RuntimeError as err:
            print(err)
            print(mat)
            factors = np.linalg.svd(mat.detach().cpu().numpy())
            return tuple(torch.from_numpy(f).to(device=mat.device, dtype=mat.dtype) for f in factors)

    @torch.no_grad()
    def K_preprocess_step(self):
        """Set the active slice of K to ``U @ S_hat``."""
        r = self.dynamic_rank
        self.K[:, :r] = self.U[:, :r] @ self.S_hat[:r, :r]

    @torch.no_grad()
    def L_preprocess_step(self):
        """Set the active slice of L to ``V @ S_hat.T``."""
        r = self.dynamic_rank
        self.L[:, :r] = self.V[:, :r] @ self.S_hat[:r, :r].T

    @torch.no_grad()
    def S_preprocess_step(self):
        """Project S_hat into the augmented bases: ``S_hat[:2r, :2r] = M_hat @ S_hat @ N_hat.T``."""
        r = self.dynamic_rank
        self.S_hat[:2 * r, :2 * r] = self.M_hat[:2 * r, :r] @ self.S_hat[:r, :r] @ self.N_hat[:2 * r, :r].T

    @torch.no_grad()
    def K_postprocess_step(self):
        """Orthonormalise ``[K | U]`` into U_hat and store ``M_hat = U_hat.T @ U``."""
        r = self.dynamic_rank
        stacked = torch.hstack((self.K[:, :r], self.U[:, :r]))
        self.U_hat[:, :2 * r] = self._orthonormal_basis(stacked)
        self.M_hat[:2 * r, :r] = self.U_hat[:, :2 * r].T @ self.U[:, :r]

    @torch.no_grad()
    def L_postprocess_step(self):
        """Orthonormalise ``[L | V]`` into V_hat and store ``N_hat = V_hat.T @ V``."""
        r = self.dynamic_rank
        stacked = torch.hstack((self.L[:, :r], self.V[:, :r]))
        self.V_hat[:, :2 * r] = self._orthonormal_basis(stacked)
        self.N_hat[:2 * r, :r] = self.V_hat[:, :2 * r].T @ self.V[:, :r]

    @torch.no_grad()
    def S_postprocess_step(self):
        """
        Truncate the augmented core by SVD and update U, S_hat, V and ``dynamic_rank``.

        The new rank is the first index ``j`` at which the norm of the singular values
        ``d[j:2r-1]`` drops below ``tau * ||d||``; it is then capped at ``self.rmax`` and
        kept at 2 or above (or at ``self.rmax`` when that is smaller than 2).
        """
        two_r = 2 * self.dynamic_rank
        core = torch.clone(self.S_hat[:two_r, :two_r])
        u2, d, v2h = self._svd(core)

        tol = self.tau * torch.linalg.norm(d)
        new_rank = d.shape[0] // 2
        tail_end = 2 * new_rank - 1
        for j in range(tail_end):
            if torch.linalg.norm(d[j:tail_end]) < tol:
                new_rank = j
                break

        # Cap first, then apply the floor without ever exceeding the allocated width.
        new_rank = max(min(new_rank, self.rmax), min(2, self.rmax))

        self.S_hat[:new_rank, :new_rank] = torch.diag(d[:new_rank])
        self.U[:, :new_rank] = self.U_hat[:, :two_r] @ u2[:, :new_rank]
        # torch.linalg.svd returns Vh (V transposed): the leading right singular vectors
        # are its first ROWS, so they must be transposed to become columns of V.
        self.V[:, :new_rank] = self.V_hat[:, :two_r] @ v2h[:new_rank, :].T
        self.dynamic_rank = int(new_rank)

    @torch.no_grad()
    def update_V(self):  ## TODO: To fix
        """Placeholder; does nothing."""
        pass

    @torch.no_grad()
    def update_Q(self):  ## TODO: To fix
        """Placeholder; does nothing."""
        pass
